# import sys
# sys.path.append(r"F:\\0-code\\pinnwor")
# print(sys.path)

import os
import json
import time
from pathlib import Path

import numpy as np
import sys

sys.path.append(str(Path(__file__).resolve().parents[2]))

from mindspore.common import *
from mindspore.common.initializer import *
from mindspore import context, Tensor, nn
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.profiler import Profiler

from pinn.architecture import MultiScaleFCCell
from pinn.loss import Constraints
from pinn.solver import Solver, LossAndTimeMonitor

from src.dataset import get_test_data, create_random_dataset
from src.callback import TlossCallback
from src.Schrodinger import Schrodinger
from src.architecture import Schrodinger_Net
# from pinn.loss import Schrodinger

# Fixed seed shared by the MindSpore global RNG (set_seed) and NumPy's global RNG,
# so repeated runs of this script start from the same random state.
_GLOBAL_SEED = 123456
set_seed(_GLOBAL_SEED)
np.random.seed(_GLOBAL_SEED)


def _build_optimizer(params, config):
    """Create the training optimizer for ``params``.

    Returns ``nn.Adam`` when ``config["train_process_first"]`` is truthy.
    Otherwise returns ``nn.Momentum`` (momentum 0.9, no weight decay) wrapped
    in ``nn.LARS``. Only the selected optimizer is constructed. The original
    code always built Momentum+LARS and then discarded it in the Adam case.
    """
    if config.get("train_process_first", False):
        return nn.Adam(params, learning_rate=config["lr"])
    momentum_opt = nn.Momentum(params, learning_rate=config["lr"], momentum=0.9, weight_decay=0.0)
    return nn.LARS(momentum_opt, epsilon=1e-08, coefficient=0.02)


def _build_callbacks(config, model, steps_per_epoch):
    """Assemble the training callbacks.

    The list always contains a ``LossAndTimeMonitor``. A ``TlossCallback``
    evaluating ``model`` on the data at ``config["test_data_path"]`` is added
    when ``train_with_eval`` is truthy. A ``ModelCheckpoint`` writing to
    ``config["save_ckpt_path"]`` is added when ``save_ckpt`` is truthy.
    """
    callbacks = [LossAndTimeMonitor(steps_per_epoch)]
    if config.get("train_with_eval", False):
        inputs, label = get_test_data(config["test_data_path"])
        callbacks.append(TlossCallback(model, inputs, label))
    if config.get("save_ckpt", False):
        ckpt_config = CheckpointConfig(save_checkpoint_steps=config["save_checkpoint_steps"],
                                       keep_checkpoint_max=config.get("keep_checkpoint_max", 2))
        # NOTE: the 'schordinger' spelling is kept on purpose: it is part of the
        # checkpoint file names that other scripts may look for.
        callbacks.append(ModelCheckpoint(prefix='ckpt_schordinger',
                                         directory=config["save_ckpt_path"], config=ckpt_config))
    return callbacks


def train(config):
    """Build the Schrodinger PINN training pipeline from ``config`` and run it.

    Steps:
      1. Configure MindSpore for static-graph mode.
      2. Build the random training dataset and its batched MindSpore dataset.
      3. Create one ``Schrodinger`` problem per sub-dataset, then the constraints.
      4. Choose the optimizer and optionally restore model weights.
      5. Assemble the callbacks and call ``Solver.train``.

    Args:
        config: dict of settings, normally loaded from ``config.json``.
            Required keys: ``device_target``, ``device_id``, ``batch_size``,
            ``lr`` and ``train_epoch``. Also required are ``load_ckpt_path``
            when ``load_ckpt`` is truthy, ``test_data_path`` when
            ``train_with_eval`` is truthy, and ``save_checkpoint_steps`` plus
            ``save_ckpt_path`` when ``save_ckpt`` is truthy.
            Optional keys default to the previously hard-coded behaviour:
            ``train_process_first``, ``load_ckpt``, ``train_with_eval`` and
            ``save_ckpt`` (False); ``save_graphs`` (True); ``save_graphs_path``
            ("./graph"); ``amp_level`` ("O3"); ``keep_checkpoint_max`` (2).
            ``config`` is also passed to ``create_random_dataset``, which may
            read further keys.

    Raises:
        KeyError: if a required key is missing from ``config``.
    """
    # Static graph mode. Graph IR dumps are debug output and can be disabled
    # via config["save_graphs"].
    context.set_context(mode=context.GRAPH_MODE,
                        save_graphs=config.get("save_graphs", True),
                        device_target=config["device_target"],
                        device_id=config["device_id"],
                        save_graphs_path=config.get("save_graphs_path", "./graph"))

    # dataset
    elec_train_dataset = create_random_dataset(config)
    train_dataset = elec_train_dataset.create_dataset(batch_size=config["batch_size"],
                                                      shuffle=True,
                                                      drop_remainder=True)
    # NOTE(review): this is len() of the wrapper object, not of the batched
    # `train_dataset`. Confirm it counts batches (drop_remainder=True) rather
    # than samples before relying on it for per-epoch step reporting.
    steps_per_epoch = len(elec_train_dataset)
    print("check train dataset size: ", steps_per_epoch)

    model = Schrodinger_Net()
    print("num_losses=", elec_train_dataset.num_dataset)

    # One problem per sub-dataset. Domain, boundary and initial conditions all
    # read the same "<name>_points" entry.
    train_prob = {}
    for dataset in elec_train_dataset.all_datasets:
        points_name = dataset.name + "_points"
        train_prob[dataset.name] = Schrodinger(model=model,
                                               domain_name=points_name,
                                               bc_name=points_name,
                                               ic_name=points_name)
    print("check problem: ", train_prob)
    train_constraints = Constraints(elec_train_dataset, train_prob)

    optim = _build_optimizer(model.trainable_params(), config)

    if config.get("load_ckpt", False):
        param_dict = load_checkpoint(config["load_ckpt_path"])
        load_param_into_net(model, param_dict)

    solver = Solver(model,
                    optimizer=optim,
                    mode="PINNs",
                    train_constraints=train_constraints,
                    test_constraints=None,
                    amp_level=config.get("amp_level", 'O3'))
    print("steps_per_epoch=", steps_per_epoch)
    callbacks = _build_callbacks(config, model, steps_per_epoch)
    print("callbacks=", callbacks)
    solver.train(config["train_epoch"], train_dataset, callbacks=callbacks)


if __name__ == '__main__':
    print("pid:", os.getpid())
    # To collect MindSpore profiling data, uncomment this line together with
    # `profiler.analyse()` below.
    # profiler = Profiler(output_path='./profiler_data')

    # Read the config with a context manager so the file handle is closed.
    # Decode explicitly as UTF-8 rather than the platform locale encoding.
    with open("./config.json", encoding="utf-8") as config_file:
        configs = json.load(config_file)
    print("check config: {}".format(configs))

    # perf_counter is monotonic, so system clock adjustments cannot skew
    # the measured duration.
    time_beg = time.perf_counter()
    train(configs)
    # profiler.analyse()
    print("End-to-End total time: {} s".format(time.perf_counter() - time_beg))
